

# linear algebra utility functions

using LinearAlgebra


# solve the continuous-time Lyapunov equation given a precomputed Schur decomposition
# (faster when solving repeatedly with the same matrix); edited from LinearAlgebra source dense.jl
function lyap(
        R::StridedMatrix{T},
        Q::StridedMatrix{T},
        C::StridedMatrix{T}
    ) where {T<:LinearAlgebra.BlasFloat}
    # Solves A*X + X*A' + C = 0 where A = Q*R*Q' is a Schur decomposition
    # (R = Schur form, Q = unitary Schur vectors).
    # Returns X (scaled back by the LAPACK `scale` factor).

    # move the constant term into the Schur basis: D = -Q' * C * Q
    CQ = C * Q
    D = -(Q' * CQ)
    # second operand of the Sylvester solve is R' (conjugate transpose for complex T)
    op_second = T <: Complex ? 'C' : 'T'
    # trsyl! solves R*Y + Y*op(R) = scale*D, overwriting D
    Y, scale = LAPACK.trsyl!('N', op_second, R, R, D)
    # rotate back to the original basis and undo the LAPACK scaling
    X = Q * (Y * Q')
    return rmul!(X, inv(scale))
end

#
# # faster than findfirst but less complete
# function find_val_index(val::Float64, arr::Vector{Float64})::Int
#     idx = -1
#     for (i,v) in enumerate(arr)
#         if v==val return i; end
#     end
#     return idx
# end


# half-vectorization for square matrix
function vech(A::Array{T,2})::Vector{T} where T
    # Half-vectorization: stack the lower triangle (diagonal included) of square A
    # column by column into a vector of length n*(n+1)/2.
    # Throws DimensionMismatch (via checksquare) if A is not square.
    n = LinearAlgebra.checksquare(A)
    out = Vector{T}(undef, div(n * (n + 1), 2))
    pos = 1
    for col in 1:n
        for row in col:n
            out[pos] = A[row, col]
            pos += 1
        end
    end
    return out
end


##############
# matrix exponential functions

# note: for large matrices and one time computations, use exp()
function calc_exp_matrix_tau(
        tau::Float64,
        A_eigvals::AbstractVector,
        A_eigvecs::AbstractMatrix
    )::Matrix{Float64}
    # exp(-A*tau) from the eigendecomposition A = V * Diagonal(lam) * inv(V):
    #   exp(-A*tau) = V * Diagonal(exp(-lam*tau)) * inv(V)
    # Only the real part is returned (imaginary parts, e.g. round-off from
    # complex-conjugate eigenpairs, are discarded).
    decay = Diagonal(exp.(-A_eigvals .* tau))
    scaled_vecs = A_eigvecs * decay
    # right-division by V instead of explicitly forming inv(V)
    return real.(scaled_vecs / A_eigvecs)
end


function calc_exp_matrix_tau_integral(
        tau::Float64,
        A_eigvals::AbstractVector,
        A_eigvecs::AbstractMatrix
    )::Matrix{Float64}
    # int_0^tau exp(-A*s) ds from the eigendecomposition A = V * Diagonal(lam) * inv(V):
    #   V * Diagonal((1 - exp(-lam*tau))/lam) * inv(V)
    # Each factor uses expm1 for accuracy at small |lam|. An exactly zero eigenvalue
    # (singular A) uses the limit value tau instead of evaluating 0/0 = NaN.
    # Only the real part of the result is returned.
    I_exp_eigvals = Diagonal(map(
        lam -> iszero(lam) ? tau + zero(lam) : -expm1(-lam * tau) / lam,
        A_eigvals))
    # rotate back with the eigenvectors (right-division instead of forming inv(V))
    int_term = (A_eigvecs * I_exp_eigvals) / A_eigvecs
    return real.(int_term)
end


############################################################################################
# methods for computing directional derivatives of exp matrix and eigenvalues efficiently

# simple eigenvalue deriv
function deriv_simple_eigval(
        u0::Vector{T},
        v0::Vector{T},
        dZ::AbstractMatrix
    )::T where T<:Number
    # First-order change of a simple eigenvalue in direction dZ:
    #   dlam = (v0' * dZ * u0) / (v0' * u0)
    # u0 = right eigenvector, v0 = left eigenvector (`dot` conjugates v0).
    numer = dot(v0, dZ * u0)
    denom = dot(v0, u0)
    return numer / denom
end
# wrapper for deriv wrt Z_{j,k}
# Derivative of a simple eigenvalue in the direction dZ = e_j * e_k^T, i.e. a perturbation
# of the single entry Z[j_idx, k_idx]:
#   dlam = conj(v0[j_idx]) * u0[k_idx] / (v0' * u0)
# u0 = right eigenvector, v0 = left eigenvector in the same convention as the `dot`-based
# method (dot conjugates its first argument, e.g. columns of inv(U0)').
function deriv_simple_eigval(
        u0::Vector{T},
        v0::Vector{T},
        j_idx::Int,
        k_idx::Int
    )::T where T<:Number
    vu_prod = dot( v0, u0 )
    # conj is required so this equals dot(v0, dZ*u0)/dot(v0, u0) with dZ = e_j e_k^T;
    # it is a no-op for real T
    dlam = (conj(v0[j_idx]) * u0[k_idx]) / vu_prod
    return dlam
end

# First-order derivatives of a simple eigenpair (lam0, u0) of Z0 in direction dZ.
#   lam0, u0 : eigenvalue and right eigenvector of Z0
#   v0       : left eigenvector (`dot` conjugates it; dot(v0, u0) must be nonzero)
# Returns (dlam, du) where
#   dlam = v0' * dZ * u0 / (v0' * u0)
#   du   = pinv(lam0*I - Z0) * ((dZ - dlam*I) * u0)
# The pinv gives the minimum-norm solution of (lam0*I - Z0) du = (dZ - dlam*I) u0.
function deriv_simple_eigval_eigvec(
        lam0::T,
        u0::Vector{T},
        v0::Vector{T},
        Z0::Matrix{Float64},
        dZ::AbstractMatrix
    )::Tuple{T, Vector{T}} where T<:Number
    dZu = dZ * u0                      # reused for both derivatives
    dlam = dot(v0, dZu) / dot(v0, u0)
    # (I - u0*v0'/(v0'u0)) * dZ * u0 == dZu - dlam*u0; avoid building the N×N projector
    rhs = dZu .- dlam .* u0
    du = pinv(lam0 * I - Z0) * rhs
    return dlam, du
end

# derivs of all eigvals/vecs
function deriv_simple_eigdecomp(
        Lam_diag0::Vector{T},
        U0::Matrix{T},
        Z0::Matrix{Float64},
        dZ::AbstractMatrix
    )::Tuple{Vector{T}, Matrix{T}} where T<:Number
    # Derivatives of every (simple) eigenvalue/eigenvector of Z0 in direction dZ.
    #   Lam_diag0 : eigenvalues, U0 : right eigenvectors as columns
    # Returns (dLam_diag, dU): eigenvalue derivatives and eigenvector derivatives
    # (column j of dU belongs to eigenvalue j).

    # left eigenvectors as columns, normalized so that V0' * U0 == I
    V0 = inv(U0)'
    dLam_diag = zero(Lam_diag0)
    dU = zero(U0)
    for col in 1:size(U0, 1)
        dlam_col, du_col = deriv_simple_eigval_eigvec(
            Lam_diag0[col], U0[:, col], V0[:, col], Z0, dZ)
        dLam_diag[col] = dlam_col
        dU[:, col] = du_col
    end
    return dLam_diag, dU
end


############################################################################################
# note:
# for computing derivs of exp matrices, must compute auxiliary Phi matrix
# methods below compute this inplace
# use AbstractMatrix to allow for passing SubArray (views) to avoid allocations

# exp(-A*tau)
function calc_dexp_Phi_mat!(Phi_mat::AbstractMatrix, tau::Float64,
        A_eigvals::AbstractVector)::Nothing
    # Fill Phi_mat in place with the divided differences of lam -> exp(-tau*lam):
    #   Phi[i,j] = (exp(-tau*lam_i) - exp(-tau*lam_j)) / (lam_i - lam_j)   if lam_i != lam_j
    #   Phi[i,j] = -tau * exp(-tau*lam_j)                                   if lam_i == lam_j
    # Eigenvalues must be ComplexF64 (enforced by the type assertions).
    n = length(A_eigvals)
    decay = exp.(-tau .* A_eigvals)::Vector{ComplexF64}
    for i in 1:n, j in 1:n
        lam_i = A_eigvals[i]::ComplexF64
        lam_j = A_eigvals[j]::ComplexF64
        Phi_mat[i, j] = if lam_i == lam_j
            -tau * decay[j]
        else
            (decay[i] - decay[j]) / (lam_i - lam_j)
        end
    end
    return nothing
end

# int_0^tau exp(-A*s) ds
# Fill `int_Phi_mat` in place with the elementwise integral over s in [0, tau] of the
# divided-difference matrix of lam -> exp(-s*lam):
#   r(lam) = (1 - exp(-tau*lam)) / lam            (r(0) = tau, its limit)
#   lam_i != lam_j : (r(lam_i) - r(lam_j)) / (lam_i - lam_j)
#   lam_i == lam_j : -r(lam)/lam + tau*exp(-tau*lam)/lam   (limit -tau^2/2 at lam == 0)
# Zero eigenvalues (singular A) use these limits instead of dividing by zero.
# Eigenvalues must be ComplexF64 (enforced by the type assertions).
function calc_dexp_int_Phi_mat!(int_Phi_mat::AbstractMatrix, tau::Float64,
        A_eigvals::AbstractVector)::Nothing
    J = size(A_eigvals, 1)
    et_Lam = exp.(-tau.*A_eigvals)::Vector{ComplexF64}
    # expm1 keeps r(lam) accurate for small |lam|
    et_Lam_ratio = map(lam -> iszero(lam) ? tau + zero(lam) : -expm1(-tau * lam) / lam,
        A_eigvals)::Vector{ComplexF64}
    for i=1:J
        lam_i = A_eigvals[i]::ComplexF64
        et_lam_i_ratio = et_Lam_ratio[i]
        for j=1:J
            lam_j = A_eigvals[j]::ComplexF64
            et_lam_j = et_Lam[j]
            et_lam_j_ratio = et_Lam_ratio[j]
            if lam_i != lam_j
                phi_ij = (et_lam_i_ratio - et_lam_j_ratio)/(lam_i-lam_j)
            elseif iszero(lam_j)
                # integral of -s over [0, tau]
                phi_ij = ComplexF64(-tau^2 / 2)
            else
                phi_ij = -et_lam_j_ratio/lam_j + tau*et_lam_j/lam_j
            end
            int_Phi_mat[i,j] = phi_ij
        end
    end
    return nothing
end


# taking as given Phi/int_Phi mat
function deriv_exp_matrix_tau(Phi_mat::AbstractMatrix{ComplexF64},
        A_eigvecs::AbstractMatrix{ComplexF64}, dA::AbstractMatrix{Float64})::Matrix{Float64}
    # Directional derivative in direction dA, given a precomputed Phi (or int_Phi) matrix:
    #   real( V * ((inv(V) * dA * V) .* Phi) * inv(V) ),  V = A_eigvecs
    V = A_eigvecs
    dA_V = dA * V
    V_bar = (V \ dA_V)::Matrix{ComplexF64}             # direction in the eigenbasis
    weighted = (V_bar .* Phi_mat)::Matrix{ComplexF64}  # Hadamard product
    back = V * weighted
    return real.(back / V)
end

function deriv_exp_matrix_tau(Phi_mat::AbstractMatrix{ComplexF64},
        A_eigvecs::AbstractMatrix{ComplexF64}, A_eigvecs_inv::AbstractMatrix{ComplexF64},
        j_idx::Int, k_idx::Int)::Matrix{Float64}
    # Same derivative for the elementary direction dA = e_j * e_k^T.
    # Here inv(V) * dA * V reduces to the outer product of column j of inv(V)
    # and row k of V.
    col_j = A_eigvecs_inv[:, j_idx]
    row_k = transpose(A_eigvecs[k_idx, :])
    V_bar = (col_j * row_k)::Matrix{ComplexF64}
    weighted = (V_bar .* Phi_mat)::Matrix{ComplexF64}   # Hadamard product
    return real.((A_eigvecs * weighted) * A_eigvecs_inv)
end


##################################################################################################
# backported from a newer version of the LinearAlgebra stdlib
# can be removed once the LinearAlgebra version in use provides these methods
# https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/qr.jl

using LinearAlgebra
# Determinant of a QR factor Q computed from its Householder coefficients τ.
# NOTE(review): if the LinearAlgebra version in use already defines these `det` methods,
# the definitions below overwrite them (Julia warns about method overwriting) — confirm
# and consider removing this backport.
LinearAlgebra.det(Q::LinearAlgebra.QRPackedQ) = _det_tau(Q.τ)

# Compact-WY storage: the τ values of each block of nb = size(Q.T, 1) columns lie on the
# diagonal of that block of Q.T; the determinant is the product over blocks.
function LinearAlgebra.det(Q::LinearAlgebra.QRCompactWYQ)
    nb, ncols = size(Q.T)
    return prod(1:nb:ncols) do first_col
        block = Q.T[:, first_col:min(first_col + nb, ncols)]
        _det_tau(_diagview(block))
    end
end

# Non-copying view of the main diagonal of `A`.
_diagview(A) = view(A, diagind(A))

# Real τ: each nonzero τ is one Householder reflection (determinant -1); τ == 0 means
# no reflection, which also covers the case where `Q.τ` contains zeros.
function _det_tau(τs::AbstractVector{<:Real})
    nreflections = count(!iszero, τs)
    return iseven(nreflections) ? one(eltype(τs)) : -one(eltype(τs))
end

# Complex τ: each reflector has one non-unit eigenvalue λ = 1 - c*τ (c = v'v > 0).
# Because the reflector is unitary, abs2(λ) == 1. Together with c > 0 this gives
# λ = -sign(τ)^2, which is the reflector's determinant.
# See: https://github.com/JuliaLang/julia/pull/32887#issuecomment-521935716
function _det_tau(τs)
    return prod(τs) do τ
        iszero(τ) ? one(τ) : -sign(τ)^2
    end
end
##################################################################################################
